{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "63f55b9d",
   "metadata": {},
   "source": [
    "## basics"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8dc95a6",
   "metadata": {},
   "source": [
    "- data collator: collates data， prepare for the model input from raw inputs\n",
    "    - 比如长度处理, padding\n",
    "    - 更广义上来说，数据预处理；\n",
    "\n",
    "```\n",
    "# bs: 4\n",
    "1, 3, 4\n",
    "1, 7, 5, 2, 2, 4\n",
    "1, 3, 2, 4\n",
    "1, 6, 4\n",
    "\n",
    "# bs: 4, seq_len: 6\n",
    "# padding\n",
    "1, 3, 4, 0, 0, 0\n",
    "1, 7, 5, 2, 2, 4\n",
    "1, 3, 2, 4, 0, 0\n",
    "1, 6, 4, 0, 0, 0\n",
    "```"
   ]
  },
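  {
   "cell_type": "markdown",
   "id": "pad-sketch-01",
   "metadata": {},
   "source": [
    "The padding step illustrated above can be sketched in plain Python. This is a minimal illustration, not the `transformers` implementation: the helper name `pad_batch` and the choice of `0` as the pad id are assumptions made for this example. The attention mask marks real tokens with 1 and padding with 0.\n",
    "\n",
    "```python\n",
    "def pad_batch(sequences, pad_id=0):\n",
    "    # Pad every sequence on the right to the length of the longest one.\n",
    "    max_len = max(len(seq) for seq in sequences)\n",
    "    input_ids = [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]\n",
    "    # 1 for real tokens, 0 for padding positions.\n",
    "    attention_mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in sequences]\n",
    "    return {'input_ids': input_ids, 'attention_mask': attention_mask}\n",
    "\n",
    "\n",
    "batch = pad_batch([[1, 3, 4], [1, 7, 5, 2, 2, 4], [1, 3, 2, 4], [1, 6, 4]])\n",
    "for row in batch['input_ids']:\n",
    "    print(row)\n",
    "```\n",
    "\n",
    "In practice `DataCollatorWithPadding` delegates this work to the tokenizer, which knows the model's actual pad token id."
   ]
  },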
  {
   "cell_type": "markdown",
   "id": "dc05ea39",
   "metadata": {},
   "source": [
    "### 一个标准用法"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ae23828",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-08-15T15:40:21.685706Z",
     "start_time": "2023-08-15T15:40:21.679674Z"
    }
   },
   "source": [
    "```\n",
    "from transformers.data import DefaultDataCollator\n",
    "from transformers import Trainer\n",
    "\n",
    "data_collator = DefaultDataCollator()\n",
    "trainer = Trainer(model=model, \n",
    "                  args=args, \n",
    "                  train_dataset=dataset, \n",
    "                  data_collator=data_collator)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ed3c36d7",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-08-15T15:49:31.415594Z",
     "start_time": "2023-08-15T15:49:31.410427Z"
    }
   },
   "outputs": [],
   "source": [
    "from transformers.data import DefaultDataCollator\n",
    "from transformers.data import DataCollatorWithPadding\n",
    "from transformers.data import DataCollatorForTokenClassification\n",
    "from transformers.data import DataCollatorForSeq2Seq\n",
    "from transformers.data import DataCollatorForLanguageModeling"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b7be2ae",
   "metadata": {},
   "source": [
    "```\n",
    "data_collator = DataCollatorWithPadding(\n",
    "    tokenizer=tokenizer\n",
    ")\n",
    "\n",
    "data_collator = DataCollatorForTokenClassification(\n",
    "    tokenizer=tokenizer,\n",
    ")\n",
    "\n",
    "\n",
    "data_collator = DataCollatorForSeq2Seq(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b4dc442",
   "metadata": {},
   "source": [
    "### for seq2seq"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89087e0c",
   "metadata": {},
   "source": [
    "```\n",
    "# pair of input\n",
    "(1, 3, 4)\n",
    "(5, 3, 10)\n",
    "\n",
    "(1, 7, 5, 2, 2, 4)\n",
    "(8, 9, 3, 8, 2, 10)\n",
    "\n",
    "(1, 3, 2, 4)\n",
    "(3, 4, 8, 10)\n",
    "\n",
    "\n",
    "# inputs\n",
    "(1, 3, 4, 0, 0, 0)\n",
    "(1, 7, 5, 2, 2, 4)\n",
    "(1, 3, 2, 4, 0, 0)\n",
    "\n",
    "# labels\n",
    "(5, 3, 10, -100, -100, -100)\n",
    "(8, 9, 3, 8, 2, 10)\n",
    "(3, 4, 8, 10, -100, -100)\n",
    "```"
   ]
  },
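  {
   "cell_type": "markdown",
   "id": "s2s-sketch-01",
   "metadata": {},
   "source": [
    "The input/label padding shown above can be sketched as follows. This is a simplified stand-in for what `DataCollatorForSeq2Seq` does, not its actual code: the function name `collate_seq2seq` and the pad id `0` are assumptions for illustration. The label pad value `-100` matches the default `ignore_index` of PyTorch's `CrossEntropyLoss`, so padded label positions do not contribute to the loss.\n",
    "\n",
    "```python\n",
    "def collate_seq2seq(pairs, pad_id=0, label_pad_id=-100):\n",
    "    # pairs: list of (input_ids, label_ids) tuples, each a Python list of ints.\n",
    "    max_in = max(len(src) for src, _ in pairs)\n",
    "    max_out = max(len(tgt) for _, tgt in pairs)\n",
    "    input_ids = [src + [pad_id] * (max_in - len(src)) for src, _ in pairs]\n",
    "    # Labels are padded with label_pad_id so the loss skips those positions.\n",
    "    labels = [tgt + [label_pad_id] * (max_out - len(tgt)) for _, tgt in pairs]\n",
    "    return {'input_ids': input_ids, 'labels': labels}\n",
    "\n",
    "\n",
    "pairs = [\n",
    "    ([1, 3, 4], [5, 3, 10]),\n",
    "    ([1, 7, 5, 2, 2, 4], [8, 9, 3, 8, 2, 10]),\n",
    "    ([1, 3, 2, 4], [3, 4, 8, 10]),\n",
    "]\n",
    "batch = collate_seq2seq(pairs)\n",
    "for row in batch['labels']:\n",
    "    print(row)\n",
    "```"
   ]
  },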
  {
   "cell_type": "markdown",
   "id": "ab1f9c83",
   "metadata": {},
   "source": [
    "### for LM"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3049d19",
   "metadata": {},
   "source": [
    "```\n",
    "\n",
    "data_collator = DataCollatorForLanguageModeling(\n",
    "    model=model,\n",
    "    mlm=False\n",
    ")\n",
    "\n",
    "data_collator = DataCollatorForLanguageModeling(\n",
    "    model=model,\n",
    "    mlm=True,\n",
    "    mlm_probability=0.15\n",
    ")\n",
    "```"
   ]
  },
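  {
   "cell_type": "markdown",
   "id": "mlm-sketch-01",
   "metadata": {},
   "source": [
    "To make the `mlm=True` case concrete, here is a rough sketch of random masking on an already padded batch. It is deliberately simplified and is not the library's code: `mask_tokens`, the mask id `99`, the pad id `0`, and the `seed` argument are all hypothetical choices for this example. The real collator additionally replaces some selected tokens with random tokens or leaves them unchanged (the 80/10/10 scheme) and avoids masking special tokens.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "def mask_tokens(input_ids, mask_id, pad_id=0, mlm_probability=0.15, seed=0):\n",
    "    # Each non-padding token is replaced by mask_id with probability\n",
    "    # mlm_probability. Labels keep the original id at masked positions\n",
    "    # and -100 everywhere else, so only masked tokens are scored.\n",
    "    rng = random.Random(seed)\n",
    "    masked, labels = [], []\n",
    "    for row in input_ids:\n",
    "        new_row, label_row = [], []\n",
    "        for tok in row:\n",
    "            if tok != pad_id and rng.random() < mlm_probability:\n",
    "                new_row.append(mask_id)\n",
    "                label_row.append(tok)\n",
    "            else:\n",
    "                new_row.append(tok)\n",
    "                label_row.append(-100)\n",
    "        masked.append(new_row)\n",
    "        labels.append(label_row)\n",
    "    return masked, labels\n",
    "\n",
    "\n",
    "padded = [[1, 3, 4, 0, 0, 0], [1, 7, 5, 2, 2, 4]]\n",
    "masked, labels = mask_tokens(padded, mask_id=99, mlm_probability=0.5)\n",
    "print(masked)\n",
    "print(labels)\n",
    "```\n",
    "\n",
    "With `mlm=False` no masking happens: the labels are a copy of the input ids with padding positions set to `-100`."
   ]
  },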
  {
   "cell_type": "markdown",
   "id": "d0c591bb",
   "metadata": {},
   "source": [
    "```\n",
    "# bs: 4\n",
    "1, 3, 4\n",
    "1, 7, 5, 2, 2, 4\n",
    "1, 3, 2, 4\n",
    "1, 6, 4\n",
    "\n",
    "# bs: 4, seq_len: 6\n",
    "# padding\n",
    "1, [MASK], 4, 0, 0, 0\n",
    "1, 7, 5, [MASK], 2, 4\n",
    "1, 3, 2, [MASK], 0, 0\n",
    "1, [MASK], 4, 0, 0, 0\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
